Skip to content

feat(pt): add hard-coded aparam output gate for fitting nets#5495

Open
Jingbei-Bai wants to merge 8 commits into
deepmodeling:masterfrom
Jingbei-Bai:feat/aparam-output-gate
Open

feat(pt): add hard-coded aparam output gate for fitting nets#5495
Jingbei-Bai wants to merge 8 commits into
deepmodeling:masterfrom
Jingbei-Bai:feat/aparam-output-gate

Conversation

@Jingbei-Bai

@Jingbei-Bai Jingbei-Bai commented Jun 4, 2026

Copy link
Copy Markdown

Summary by CodeRabbit

  • New Features

    • Optional per-atom output gating controlled by three new fitting options; example config updated.
    • New variational-Gaussian averaged smooth descriptor (se_a_vg) and supporting descriptor exports.
    • Descriptors can optionally accept per-atom parameters; model wiring respects that.
  • Bug Fixes / Behavior

    • Output gating applied after output statistics/bias so gated atomic outputs honor aparam inputs.
    • Descriptor behavior unchanged when aparam is unused.
  • Tests

    • Added tests covering gating behavior, formula, serialization, reshape handling, and se_a_vg.

Co-authored-by: Cursor <cursoragent@cursor.com>
Copilot AI review requested due to automatic review settings June 4, 2026 10:58
@dosubot dosubot Bot added the new feature label Jun 4, 2026

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds an optional “aparam output gate” to fitting networks (PyTorch + array-API/dpmodel) and exposes it through configuration/argcheck, with tests and an example config update.

Changes:

  • Introduce use_aparam_output_gate, aparam_gate_norm, and aparam_gate_clamp arguments, including validation + serialization.
  • Apply the gate in PT fitting forward paths and dpmodel array-API fitting path.
  • Add a dedicated PT unit test and update an example training input JSON.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
source/tests/pt/model/test_aparam_output_gate.py Adds unit tests validating gating behavior + serialize roundtrip
examples/fparam/train/input_aparam.json Demonstrates new config knobs in an example input
deepmd/utils/argcheck.py Exposes new fitting arguments and documentation strings
deepmd/pt/model/task/sezm_ener.py Extracts raw aparam and applies output gate in SE(Z/M) energy path
deepmd/pt/model/task/fitting.py Implements gate computation/application and wires it into common forward path
deepmd/dpmodel/fitting/general_fitting.py Mirrors gate logic for dpmodel/array-API fitting path

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

aparam_gate_norm=1.0,
aparam_gate_clamp=True,
).to(device)
fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype))
aparam_gate_norm=norm,
aparam_gate_clamp=False,
).to(device)
fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype))

fitting_gate = fitting._compute_aparam_output_gate(aparam)
expected = (a_val * a_val) / (sigma * sigma * norm)
self.assertTrue(torch.allclose(fitting_gate, torch.tensor(expected, dtype=dtype)))
Comment on lines +818 to +823
if aparam.numel() % (nf * self.numb_aparam) != 0:
raise ValueError(
f"input aparam: cannot reshape {list(aparam.shape)} "
f"into ({nf}, nloc, {self.numb_aparam})."
)
aparam_raw = aparam.view([nf, -1, self.numb_aparam])
Comment on lines +698 to +703
if aparam.numel() % (nf * self.numb_aparam) != 0:
raise ValueError(
f"input aparam: cannot reshape {list(aparam.shape)} "
f"into ({nf}, nloc, {self.numb_aparam})."
)
aparam_raw = aparam.view([nf, -1, self.numb_aparam])
self,
aparam_raw: torch.Tensor,
) -> torch.Tensor:
"""Hard-coded gate g = a^2 / (sigma^2 * norm) from raw aparam."""
Comment on lines +774 to +777
if self.numb_aparam > 1:
gate = gate.prod(dim=-1, keepdim=True)
if self.aparam_gate_clamp:
gate = gate.clamp(0.0, 1.0)
@coderabbitai

coderabbitai Bot commented Jun 4, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds an optional multiplicative "aparam output gate" for fitting outputs and a new variational-Gaussian descriptor (se_a_vg). Implements gate compute/apply helpers in dpmodel and PT, wires gating into atomic output flow and SeZM, adds args/examples/tests, and adds VG descriptor code, exports, and tests.

Changes

Aparam Output Gate Feature

Layer / File(s) Summary
Configuration args, example, and tests
deepmd/utils/argcheck.py, examples/fparam/train/input_aparam.json, source/tests/pt/model/test_aparam_output_gate.py
Adds fitting_aparam_output_gate_args() and inserts the fields into fitting schemas; updates example JSON; adds tests verifying zero-aparam zeroing, formula, reshape handling, and serialize roundtrip.
dpmodel GeneralFitting and atomic hook
deepmd/dpmodel/fitting/general_fitting.py, deepmd/dpmodel/atomic_model/base_atomic_model.py
Adds use_aparam_output_gate, aparam_gate_norm, aparam_gate_clamp to constructor with validation and serialization; reshapes input aparam to aparam_raw in _call_common; adds _compute_aparam_output_gate/_apply_aparam_output_gate and apply_aparam_output_gate_to_atomic_output; applies gate post out-stat via atomic-model helper.
PyTorch GeneralFitting, atomic flow, and SeZM integration
deepmd/pt/model/task/fitting.py, deepmd/pt/model/atomic_model/base_atomic_model.py, deepmd/pt/model/atomic_model/dp_atomic_model.py, deepmd/pt/model/task/sezm_ener.py
Mirrors dpmodel in PT: ctor params/validation/attrs/serialize, gate compute/apply helpers (TorchScript wrapper), _forward_common builds/validates aparam_raw and uses it for aparam_embed, enforces presence when gating enabled, conditionally forwards aparam to descriptors, and applies gate after out-stat; SeZM forward builds aparam_raw and uses normalized aparam_embed.

se_a_vg Variational-Gaussian Descriptor

Layer / File(s) Summary
VG env-mat utilities and tabulation
deepmd/pt/model/descriptor/env_mat_vg.py
Adds VG radial kernels (vg_gaussian_radial_phi, vg_smooth_radial), neighbor sigma gathering, VG env-matrix builder _make_env_mat_vg and normalizer prod_env_mat_vg, plus two-stage tabulation helper tabulate_fusion_se_a_vg.
Descriptor block and wrapper
deepmd/pt/model/descriptor/se_a_vg.py
Adds DescrptBlockSeAVg with per-type embedding networks, stats collection (compute_input_stats), compression enablement, and forward; adds DescrptSeAVg wrapper with registration, forward, compression wiring, serialize/update_sel.
Integration, exports, and args
deepmd/pt/model/descriptor/__init__.py, deepmd/pt/model/descriptor/se_a.py, deepmd/utils/argcheck.py, deepmd/pt/model/model/__init__.py
Expose new descriptor classes, accept/ignore aparam in existing DescrptSeA.forward signature, register se_a_vg args as alias, include se_a_vg in compression options, and broaden get_spin_model sel-expansion to the se_a family.
se_a_vg unit tests
source/tests/pt/model/test_se_a_vg.py
Add tests verifying sigma=0 equivalence to SE, sensitivity to aparam, forward shapes, DPAtomicModel integration, scripted behavior, and optional compression numerical match when fused op exists.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • wanghan-iapcm
  • njzjz
  • iProzd
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 28.77% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main feature added: a hard-coded aparam output gate for fitting networks in the PyTorch implementation.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@source/tests/pt/model/test_aparam_output_gate.py`:
- Line 73: The assertion compares fitting_gate (which lives on env.DEVICE) to a
CPU tensor; create the expected tensor on the same device to avoid
device-mismatch failures: when constructing torch.tensor(expected, dtype=dtype)
in the test (the line comparing fitting_gate), pass device=env.DEVICE or call
.to(env.DEVICE) so the expected tensor matches fitting_gate's device before
calling torch.allclose.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 73a16adf-a7d6-43fb-adcb-ab9b7391afb8

📥 Commits

Reviewing files that changed from the base of the PR and between fb6ff93 and afcf8b0.

📒 Files selected for processing (6)
  • deepmd/dpmodel/fitting/general_fitting.py
  • deepmd/pt/model/task/fitting.py
  • deepmd/pt/model/task/sezm_ener.py
  • deepmd/utils/argcheck.py
  • examples/fparam/train/input_aparam.json
  • source/tests/pt/model/test_aparam_output_gate.py


fitting_gate = fitting._compute_aparam_output_gate(aparam)
expected = (a_val * a_val) / (sigma * sigma * norm)
self.assertTrue(torch.allclose(fitting_gate, torch.tensor(expected, dtype=dtype)))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Fix expected tensor device in gate-formula assertion.

At Line 73, torch.tensor(expected, dtype=dtype) is created on CPU, while fitting_gate is on env.DEVICE, which can fail on GPU/MPS.

Proposed fix
-        self.assertTrue(torch.allclose(fitting_gate, torch.tensor(expected, dtype=dtype)))
+        self.assertTrue(
+            torch.allclose(
+                fitting_gate,
+                torch.tensor(expected, dtype=dtype, device=device),
+            )
+        )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/pt/model/test_aparam_output_gate.py` at line 73, The assertion
compares fitting_gate (which lives on env.DEVICE) to a CPU tensor; create the
expected tensor on the same device to avoid device-mismatch failures: when
constructing torch.tensor(expected, dtype=dtype) in the test (the line comparing
fitting_gate), pass device=env.DEVICE or call .to(env.DEVICE) so the expected
tensor matches fitting_gate's device before calling torch.allclose.

).to(device)
fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype))

descriptor = torch.randn(nf, nloc, dim_descrpt, dtype=dtype, device=device)
fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype))

descriptor = torch.randn(nf, nloc, dim_descrpt, dtype=dtype, device=device)
atype = torch.zeros(nf, nloc, dtype=torch.int64, device=device)
WAbjb1314 and others added 2 commits June 5, 2026 15:40
Gate previously multiplied only the fitting output before apply_out_stat added per-type out_bias, so sigma=0 inference still returned non-zero energy. Apply the gate after out_bias in the atomic model forward path.

Co-authored-by: Cursor <cursoragent@cursor.com>

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt/model/task/fitting.py`:
- Around line 807-813: The current validation only checks aparam_raw.shape[-1]
against self.numb_aparam but does not ensure the leading dims match outs, which
can hide shape mismatches; update the block handling aparam in the method (the
aparam_raw conversion and return of _apply_aparam_output_gate) to explicitly
attempt to reshape aparam_raw to (outs.shape[0], outs.shape[1],
self.numb_aparam) and if that reshape is impossible raise a clear ValueError
describing the expected shape (using outs.shape[0], outs.shape[1],
self.numb_aparam), otherwise use the reshaped tensor for the subsequent call to
_apply_aparam_output_gate so broadcasting/multiplication errors are avoided.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: c9aa14cb-8701-4b9f-9544-1e433e241562

📥 Commits

Reviewing files that changed from the base of the PR and between 499b230 and e5c598b.

📒 Files selected for processing (6)
  • deepmd/dpmodel/atomic_model/base_atomic_model.py
  • deepmd/dpmodel/fitting/general_fitting.py
  • deepmd/pt/model/atomic_model/base_atomic_model.py
  • deepmd/pt/model/task/fitting.py
  • deepmd/pt/model/task/sezm_ener.py
  • source/tests/pt/model/test_aparam_output_gate.py
💤 Files with no reviewable changes (1)
  • deepmd/pt/model/task/sezm_ener.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • source/tests/pt/model/test_aparam_output_gate.py

Comment thread deepmd/pt/model/task/fitting.py
WAbjb1314 and others added 2 commits June 5, 2026 16:47
Introduce variational-Gaussian smooth descriptor (se_a_vg) where sigma
enters the radial kernel and a fifth environment-matrix column per VGM II.
Wire aparam through the atomic model, register argcheck/compression hooks,
and keep fitting output gate unchanged.

Co-authored-by: Cursor <cursoragent@cursor.com>

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/utils/argcheck.py (1)

2285-2290: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Enforce a positive lower bound for aparam_gate_norm.

Line 2285 currently accepts 0/negative values, but this factor is used in a denominator; that can cause divide-by-zero or invalid gate scaling at runtime. Add schema validation here to fail fast.

Proposed fix
         Argument(
             "aparam_gate_norm",
             float,
             optional=True,
             default=1.0,
+            extra_check=lambda x: x > 0.0,
+            extra_check_errmsg="must be greater than 0",
             doc=doc_aparam_gate_norm,
         ),
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/utils/argcheck.py` around lines 2285 - 2290, The schema entry for
"aparam_gate_norm" currently allows zero/negative values which can cause
divide-by-zero; update the argument schema in deepmd/utils/argcheck.py for the
"aparam_gate_norm" field to enforce a strict positive lower bound (e.g., min >
0) or add a validator that raises an error if the provided value is <= 0 so the
check fails fast; locate the schema definition containing "aparam_gate_norm",
adjust its validation rules (or add a custom validator function) to reject
non-positive values while keeping the default=1.0.
🧹 Nitpick comments (2)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

290-291: ⚡ Quick win

Avoid per-call inspect.signature(...) in the forward hot path.

Line 290 recomputes the signature every batch. Cache whether aparam is supported once in __init__ and reuse it in forward_atomic.

Proposed refactor
 class DPAtomicModel(BaseAtomicModel):
@@
     def __init__(
@@
         self.eval_descriptor_list = []
         self.eval_fitting_last_layer_list = []
+        self._descriptor_accepts_aparam = (
+            "aparam" in inspect.signature(self.descriptor.forward).parameters
+        )
@@
-        if "aparam" in inspect.signature(self.descriptor.forward).parameters:
+        if self._descriptor_accepts_aparam:
             descriptor_kwargs["aparam"] = aparam
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/atomic_model/dp_atomic_model.py` around lines 290 - 291, The
code currently calls inspect.signature(self.descriptor.forward) inside
forward_atomic for every batch; instead determine once in __init__ whether the
descriptor.forward accepts "aparam" (e.g. set self._descriptor_supports_aparam =
"aparam" in inspect.signature(self.descriptor.forward).parameters) and then in
forward_atomic replace the per-call inspect.signature check with that cached
boolean to conditionally add descriptor_kwargs["aparam"] = aparam; update any
initialization path that sets self.descriptor to ensure the cached flag is
computed after descriptor is assigned.
deepmd/pt/model/descriptor/se_a.py (1)

325-365: ⚡ Quick win

Document the aparam parameter in the docstring.

The forward method now accepts an aparam parameter (line 325) but it is not documented in the docstring (lines 333-364). Even though DescrptSeA does not use this parameter (it's immediately deleted on line 365), documenting it helps maintain API clarity and aids developers who may reference this signature.

📝 Suggested docstring addition
         comm_dict
             The data needed for communication for parallel inference.
+        aparam
+            Atomic parameters. Not used by this descriptor; accepted for
+            interface compatibility with aparam-aware descriptors.
 
         Returns
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/se_a.py` around lines 325 - 365, The forward
method signature includes an unused parameter aparam that is deleted
immediately; update the forward docstring to include a brief description of
aparam (type torch.Tensor | None, optional), its purpose or note that it is
accepted for API compatibility and currently unused, and mention that it will be
ignored (or deleted) within DescrptSeA.forward to clarify behavior for callers
and maintainers; locate the method by the forward function in this module and
add the aparam entry to the existing Parameters section.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt/model/descriptor/env_mat_vg.py`:
- Line 48: sigma_loc is being cast to nlist.dtype which is integer and truncates
fractional aparam values; remove the cast to nlist.dtype and instead preserve
aparam's floating dtype (or if device alignment is needed, cast only the device:
.to(device=nlist.device) or .to(dtype=aparam.dtype, device=nlist.device)).
Update the assignment for sigma_loc (from aparam[:, :nloc, 0]) to stop
converting to nlist.dtype so sigma retains its fractional values.

In `@deepmd/pt/model/descriptor/se_a_vg.py`:
- Around line 603-609: The serialization currently hardcodes "trainable": True;
change it to emit the actual flag from the descriptor object (e.g., use
obj.trainable or getattr(obj, "trainable", True)) in the serialize() output so a
descriptor created with trainable=False round-trips correctly; update the entry
replacing the literal True with the object's trainable attribute in the block
that builds the dict (the one referencing "`@variables`", "type_map", and
"type_one_side").
- Around line 282-305: The aparam tensor (aparam_t) is only sized for local
atoms but prod_env_mat_vg expects per-atom parameters for the extended atom set
used by nlist; expand aparam to the extended-atom domain before calling
prod_env_mat_vg by mapping local aparam values to extended indices (use the same
mapping that produced extended_atype/extended_coord from
extend_input_and_build_neighbor_list) to create extended_aparam (matching
extended_atype.shape/length) and pass that instead of aparam_t; apply the same
fix at the second call site around prod_env_mat_vg later in the file (the block
at ~378-401) so halo/periodic neighbors use the correct VG parameters.
- Around line 134-140: Normalize sel to a list before computing its length: move
or duplicate the normalization (self.sel = sel if isinstance(sel, list) else
[sel]) so that self.ntypes is set from len(self.sel) (not len(sel)), and then
compute self.sec using np.cumsum(self.sel); update the assignments around
self.ntypes, self.sel, and self.sec in __init__ (or set self.ntypes =
len(self.sel) immediately after the existing self.sel assignment) to avoid
calling len() on an int.

---

Outside diff comments:
In `@deepmd/utils/argcheck.py`:
- Around line 2285-2290: The schema entry for "aparam_gate_norm" currently
allows zero/negative values which can cause divide-by-zero; update the argument
schema in deepmd/utils/argcheck.py for the "aparam_gate_norm" field to enforce a
strict positive lower bound (e.g., min > 0) or add a validator that raises an
error if the provided value is <= 0 so the check fails fast; locate the schema
definition containing "aparam_gate_norm", adjust its validation rules (or add a
custom validator function) to reject non-positive values while keeping the
default=1.0.

---

Nitpick comments:
In `@deepmd/pt/model/atomic_model/dp_atomic_model.py`:
- Around line 290-291: The code currently calls
inspect.signature(self.descriptor.forward) inside forward_atomic for every
batch; instead determine once in __init__ whether the descriptor.forward accepts
"aparam" (e.g. set self._descriptor_supports_aparam = "aparam" in
inspect.signature(self.descriptor.forward).parameters) and then in
forward_atomic replace the per-call inspect.signature check with that cached
boolean to conditionally add descriptor_kwargs["aparam"] = aparam; update any
initialization path that sets self.descriptor to ensure the cached flag is
computed after descriptor is assigned.

In `@deepmd/pt/model/descriptor/se_a.py`:
- Around line 325-365: The forward method signature includes an unused parameter
aparam that is deleted immediately; update the forward docstring to include a
brief description of aparam (type torch.Tensor | None, optional), its purpose or
note that it is accepted for API compatibility and currently unused, and mention
that it will be ignored (or deleted) within DescrptSeA.forward to clarify
behavior for callers and maintainers; locate the method by the forward function
in this module and add the aparam entry to the existing Parameters section.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: e1861c11-00bd-407a-bffa-34f107386e4f

📥 Commits

Reviewing files that changed from the base of the PR and between e5c598b and 444982e.

📒 Files selected for processing (8)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
  • deepmd/pt/model/descriptor/__init__.py
  • deepmd/pt/model/descriptor/env_mat_vg.py
  • deepmd/pt/model/descriptor/se_a.py
  • deepmd/pt/model/descriptor/se_a_vg.py
  • deepmd/pt/model/model/__init__.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/model/test_se_a_vg.py

Comment thread deepmd/pt/model/descriptor/env_mat_vg.py Outdated
Comment on lines +134 to +140
self.ntypes = len(sel)
self.type_one_side = type_one_side
self.seed = seed
self.reinit_exclude(exclude_types)

self.sel = sel if isinstance(sel, list) else [sel]
self.sec = [0, *np.cumsum(self.sel).tolist()]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Normalize scalar sel before calling len().

The public API accepts sel: int | list[int], but line 134 evaluates len(sel) before the scalar case is converted on line 139. sel=64 will fail in __init__.

Suggested fix
-        self.ntypes = len(sel)
+        self.sel = sel if isinstance(sel, list) else [sel]
+        self.ntypes = len(self.sel)
         self.type_one_side = type_one_side
         self.seed = seed
         self.reinit_exclude(exclude_types)
-
-        self.sel = sel if isinstance(sel, list) else [sel]
         self.sec = [0, *np.cumsum(self.sel).tolist()]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.ntypes = len(sel)
self.type_one_side = type_one_side
self.seed = seed
self.reinit_exclude(exclude_types)
self.sel = sel if isinstance(sel, list) else [sel]
self.sec = [0, *np.cumsum(self.sel).tolist()]
self.sel = sel if isinstance(sel, list) else [sel]
self.ntypes = len(self.sel)
self.type_one_side = type_one_side
self.seed = seed
self.reinit_exclude(exclude_types)
self.sec = [0, *np.cumsum(self.sel).tolist()]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/se_a_vg.py` around lines 134 - 140, Normalize sel
to a list before computing its length: move or duplicate the normalization
(self.sel = sel if isinstance(sel, list) else [sel]) so that self.ntypes is set
from len(self.sel) (not len(sel)), and then compute self.sec using
np.cumsum(self.sel); update the assignments around self.ntypes, self.sel, and
self.sec in __init__ (or set self.ntypes = len(self.sel) immediately after the
existing self.sel assignment) to avoid calling len() on an int.

Comment thread deepmd/pt/model/descriptor/se_a_vg.py Outdated
Comment on lines +282 to +305
aparam_t = torch.tensor(
aparam_np[ff], dtype=self.prec, device=env.DEVICE
).reshape(1, nloc, 1)
extended_coord, extended_atype, _, nlist = (
extend_input_and_build_neighbor_list(
coord_t,
atype_t,
self.rcut,
self.sel,
mixed_types=False,
box=box_t,
)
)
env_mat, _, _ = prod_env_mat_vg(
extended_coord,
nlist,
extended_atype[:, :nloc],
aparam_t,
self.mean,
torch.ones_like(self.stddev),
self.rcut,
self.rcut_smth,
protection=self.env_protection,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Expand aparam to the extended-atom domain before building the VG env-mat.

Both call sites hand prod_env_mat_vg() an aparam tensor sized only for local atoms, while nlist indexes extended_coord. That makes every periodic/halo neighbor fall back to the zero-filled sigma path, so training stats and runtime descriptors diverge from the intended VG kernel whenever an extended image is selected. Please build an extended_aparam with the same mapping used for the extended coordinates before calling the env-mat helper.

Also applies to: 378-401

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/se_a_vg.py` around lines 282 - 305, The aparam
tensor (aparam_t) is only sized for local atoms but prod_env_mat_vg expects
per-atom parameters for the extended atom set used by nlist; expand aparam to
the extended-atom domain before calling prod_env_mat_vg by mapping local aparam
values to extended indices (use the same mapping that produced
extended_atype/extended_coord from extend_input_and_build_neighbor_list) to
create extended_aparam (matching extended_atype.shape/length) and pass that
instead of aparam_t; apply the same fix at the second call site around
prod_env_mat_vg later in the file (the block at ~378-401) so halo/periodic
neighbors use the correct VG parameters.

Comment on lines +603 to +609
"@variables": {
"davg": obj["davg"].detach().cpu().numpy(),
"dstd": obj["dstd"].detach().cpu().numpy(),
},
"type_map": self.type_map,
"trainable": True,
"type_one_side": obj.type_one_side,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Serialize the real trainable flag.

serialize() always emits "trainable": True, so a descriptor created with trainable=False will round-trip as trainable after save/load or compression setup.

Suggested fix
-            "trainable": True,
+            "trainable": obj.trainable,
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/se_a_vg.py` around lines 603 - 609, The
serialization currently hardcodes "trainable": True; change it to emit the
actual flag from the descriptor object (e.g., use obj.trainable or getattr(obj,
"trainable", True)) in the serialize() output so a descriptor created with
trainable=False round-trips correctly; update the entry replacing the literal
True with the object's trainable attribute in the block that builds the dict
(the one referencing "`@variables`", "type_map", and "type_one_side").

@njzjz-bot njzjz-bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found two blocking issues in the new VG env-mat path. Both are inline below.

— OpenClaw 2026.5.28 (model: custom-chat-jinzhezeng-group/gpt-5.5)

) -> torch.Tensor:
"""Map per-atom aparam to neighbor-list sigma values."""
nf, _, nnei = nlist.shape
sigma_loc = aparam[:, :nloc, 0].to(dtype=nlist.dtype)

@njzjz-bot njzjz-bot Jun 5, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This casts sigma values to nlist.dtype (int64), so fractional aparam/sigma values such as 0.5 become 0 before building sigma_ij. That makes the VG descriptor ignore non-integer sigma values and also explains why the current tests only compare 0 vs 1. Please keep this in the floating dtype/device of aparam instead of the neighbor-list dtype.

device=sigma_loc.device,
)
sigma_ext[:, :nloc] = sigma_loc
index = nlist.reshape(nf, -1)

@njzjz-bot njzjz-bot Jun 5, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nlist can contain -1 padding entries, but this gathers with the raw neighbor list. torch.gather does not accept negative indices, so any padded neighbor list will raise at runtime. _make_env_mat_vg already builds nlist_safe; this path needs the same masking/safe-index handling before gathering sigma values.

WAbjb1314 and others added 2 commits June 9, 2026 09:16
TorchScript-safe descriptor aparam forwarding, gate reshape parity with
dpmodel, and env_mat_vg JIT/stat fixes so frozen.pth matches checkpoint eval.

Co-authored-by: Cursor <cursoragent@cursor.com>

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt/model/descriptor/env_mat_vg.py`:
- Around line 66-76: The _gather_neighbor_sigma routine currently leaves
ghost/periodic-image slots in sigma_ext as zeros which makes torch.gather return
zero sigmas for ghost neighbors; change the logic so ghost indices (positions >=
nloc) are populated by the source-atom sigma via the existing mapping (or remap
nlist to local-source indices before gathering) instead of zeros: use mapping to
copy sigma_loc into sigma_ext at ghost positions (or rewrite index = nlist_safe
to map ghost entries to the corresponding local atom indices) so sigma_nei
reflects the mapped source sigma; keep the original device/dtype handling and
return semantics unchanged.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 42c29a36-04a3-4ebd-a59a-996c070bf6c1

📥 Commits

Reviewing files that changed from the base of the PR and between 444982e and ca8e74f.

📒 Files selected for processing (6)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
  • deepmd/pt/model/descriptor/env_mat_vg.py
  • deepmd/pt/model/descriptor/se_a_vg.py
  • deepmd/pt/model/task/fitting.py
  • source/tests/pt/model/test_aparam_output_gate.py
  • source/tests/pt/model/test_se_a_vg.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • source/tests/pt/model/test_aparam_output_gate.py
  • deepmd/pt/model/task/fitting.py
  • deepmd/pt/model/descriptor/se_a_vg.py

Comment on lines +66 to +76
sigma_ext = torch.zeros(
(nf, nall + 1),
dtype=sigma_loc.dtype,
device=sigma_loc.device,
)
sigma_ext[:, :nloc] = sigma_loc
nlist_safe = torch.where(mask, nlist, nall).to(torch.int64)
index = nlist_safe.reshape(nf, -1)
sigma_nei = torch.gather(sigma_ext, 1, index)
sigma_nei = sigma_nei.view(nf, nloc, nnei)
return torch.where(mask, sigma_nei, torch.zeros_like(sigma_nei))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if mapping is used elsewhere for ghost->local lookup in VG code
rg -n "mapping" deepmd/pt/model/descriptor/se_a_vg.py deepmd/pt/model/descriptor/env_mat_vg.py

# Check test fixture to see if ghost atoms are present
ast-grep --pattern 'class TestCaseSingleFrameWithNlist'

Repository: deepmodeling/deepmd-kit

Length of output: 19223


🌐 Web query:

deepmd-kit variational Gaussian descriptor sigma ghost atoms

💡 Result:

In the context of DeePMD-kit, the terms variational Gaussian descriptor, sigma, and ghost atoms do not constitute a single, standard feature or configuration parameter set within the primary, documented descriptor types (such as se_e2_a, se_atten, or dpa4) [1][2][3][4]. The official documentation focuses on Deep Potential (DP) descriptors that typically rely on smooth embedding networks [1][5][6]. Regarding the specific terms you mentioned: 1. Ghost atoms: While DeePMD-kit handles variable numbers of atoms via padding, this is generally referred to in the documentation as padding with "virtual particles" [1][4] or within the context of specific masking descriptors (e.g., se_a_mask) [1][7]. 2. Gaussian: Gaussian-based basis functions are available as an alternative to Bessel-like basis functions for radial distance representation in certain advanced descriptors like DPA4/SeZM [3]. 3. Sigma: This parameter is commonly associated with Gaussian width in radial basis functions (RBFs) in various machine learning potentials, but it is not a primary configuration keyword in the standard DeePMD-kit input files for the mainstream se_e2_a/se_atten descriptors [2][6]. If you are referencing a specific research implementation, custom extension, or an experimental branch (such as those exploring Gaussian-based descriptors or Gaussian-centered local connectivity), please consult the specific source code or paper associated with that development, as these terms are not part of the standard DeePMD-kit API [8][5]. top_results: [1][2][3][6]

Citations:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect env_mat_vg.py around the cited lines (56-90-ish) and the full _gather_neighbor_sigma definition.
FILE="deepmd/pt/model/descriptor/env_mat_vg.py"
echo "=== Context around lines 50-110 in $FILE ==="
sed -n '1,140p' "$FILE" | nl -ba | sed -n '45,110p'

echo
echo "=== Find _gather_neighbor_sigma definition ==="
rg -n "_gather_neighbor_sigma" "$FILE" || true
# Print a wider window around any match
rg -n "_gather_neighbor_sigma" "$FILE" -n | cut -d: -f1 | while read -r start; do
  s=$((start-30)); e=$((start+80));
  echo "--- $FILE:${s}-${e} ---"
  sed -n "${s},${e}p" "$FILE" | nl -ba
done

echo
echo "=== Search for sigma_ext usage in $FILE ==="
rg -n "sigma_ext" "$FILE" || true

echo
echo "=== Search for nlist_safe / torch.gather patterns in $FILE ==="
rg -n "nlist_safe|torch\.gather\(sigma_ext" "$FILE" || true

echo
echo "=== Inspect call site in se_a_vg.py where env_mat_vg is invoked ==="
SE="deepmd/pt/model/descriptor/se_a_vg.py"
rg -n "prod_env_mat_vg|env_mat_vg|_gather_neighbor_sigma" "$SE" || true

# Print likely relevant sections around where prod_env_mat_vg is called.
# Use a coarse window around matches.
for m in $(rg -n "prod_env_mat_vg" "$SE" | cut -d: -f1 | head -n 5); do
  s=$((m-40)); e=$((m+140));
  echo "--- $SE:${s}-${e} ---"
  sed -n "${s},${e}p" "$SE" | nl -ba
done

echo
echo "=== Check whether mapping is used to remap nlist / indices before gather ==="
rg -n "mapping.*nlist|nlist.*mapping|index.*mapping|reshape.*mapping" "$SE" || true

Repository: deepmodeling/deepmd-kit

Length of output: 208


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="deepmd/pt/model/descriptor/env_mat_vg.py"
SE="deepmd/pt/model/descriptor/se_a_vg.py"

# Simple line-numbered printer without nl
print_range () {
  local f="$1" s="$2" e="$3"
  awk -v s="$s" -v e="$e" '{ if(NR>=s && NR<=e) printf "%6d | %s\n", NR, $0 }' "$f"
}

echo "=== Basic file stats ==="
wc -l "$FILE"
wc -l "$SE"

echo
echo "=== Show env_mat_vg.py around lines 40-110 (target snippet is ~66-76) ==="
print_range "$FILE" 40 110

echo
echo "=== Locate _gather_neighbor_sigma and print its full definition area ==="
rg -n "def _gather_neighbor_sigma" "$FILE" || true
# If found, print a generous window around the first match
start=$(rg -n "def _gather_neighbor_sigma" "$FILE" | head -n1 | cut -d: -f1 || true)
if [ -n "${start:-}" ]; then
  print_range "$FILE" $((start-20)) $((start+140))
fi

echo
echo "=== Search sigma_ext / nlist_safe usage ==="
rg -n "sigma_ext|nlist_safe|torch\.gather\(sigma_ext" "$FILE" || true

echo
echo "=== Inspect call site in se_a_vg.py that invokes env_mat_vg / prod_env_mat_vg ==="
rg -n "prod_env_mat_vg|env_mat_vg" "$SE" || true
prod_line=$(rg -n "prod_env_mat_vg" "$SE" | head -n1 | cut -d: -f1 || true)
if [ -n "${prod_line:-}" ]; then
  print_range "$SE" $((prod_line-60)) $((prod_line+180))
fi

echo
echo "=== Check whether mapping is applied to nlist/indices before env_mat_vg gather ==="
rg -n "mapping|nlist.*mapping|index.*mapping|gather\(.*mapping|remap" "$SE" || true
rg -n "mapping" "$FILE" || true

echo
echo "=== Show deletion/discarding of mapping in se_a_vg.py (around where mapping is referenced) ==="
rg -n "del .*mapping|mapping=" "$SE" || true

Repository: deepmodeling/deepmd-kit

Length of output: 18233


Fix ghost-neighbor sigma gathering for periodic VG environments

_gather_neighbor_sigma in deepmd/pt/model/descriptor/env_mat_vg.py builds sigma_ext with sigma_ext[:, :nloc] = sigma_loc, leaving indices >= nloc (ghost/periodic-image entries that can appear in nlist) as zero; torch.gather(sigma_ext, ...) therefore returns sigma_neighbor=0 for those neighbors. Since _make_env_mat_vg computes sigma_ij = sqrt(sigma_loc^2 + sigma_neighbor^2), ghost neighbors become sigma_ij = sigma_loc and never combine with the ghost’s source-atom sigma. If ghost neighbors are intended to inherit sigma from their mapped local source, fill sigma_ext using the mapping (or remap nlist indices) instead of leaving ghost slots at zero; otherwise document that ghost-neighbor sigma is intentionally zero.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/env_mat_vg.py` around lines 66 - 76, The
_gather_neighbor_sigma routine currently leaves ghost/periodic-image slots in
sigma_ext as zeros which makes torch.gather return zero sigmas for ghost
neighbors; change the logic so ghost indices (positions >= nloc) are populated
by the source-atom sigma via the existing mapping (or remap nlist to
local-source indices before gathering) instead of zeros: use mapping to copy
sigma_loc into sigma_ext at ghost positions (or rewrite index = nlist_safe to
map ghost entries to the corresponding local atom indices) so sigma_nei reflects
the mapped source sigma; keep the original device/dtype handling and return
semantics unchanged.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants